Decision Trees & Random Forests
Tree-Based Methods
In this section, we describe tree-based methods for regression. These involve stratifying or segmenting the predictor space into a number of simple regions. In order to make a prediction for a given observation, we typically use the mean or the mode of the observations in the region to which it belongs. Since the set of splitting rules used to segment the predictor space can be summarized in a tree, these types of approaches are known as decision tree methods.
Tree-based methods are simple and useful for interpretation. However, they typically are not competitive with the best supervised learning approaches in terms of prediction accuracy. Hence, in this section we also introduce the random forest method, which involves producing multiple trees that are then combined to yield a single consensus prediction. We will see that combining a large number of trees can often result in dramatic improvements in prediction accuracy, at the expense of some loss in interpretation.
Decision trees can be applied to both regression and classification problems.
Let’s look at the following example on “Baseball salary” data. How would you stratify this?
Salary is color-coded from low (blue, green) to high (yellow, red).
We use the Hitters data set to predict a baseball player’s Salary based on Years (the number of years that he has played in the major leagues) and Hits (the number of hits that he made in the previous year).
Overall, the tree stratifies or segments the players into three regions of predictor space: \(R_1=\{X \mid \text{Years} < 4.5\}\), \(R_2=\{X \mid \text{Years} \geq 4.5,\ \text{Hits} < 117.5\}\), and \(R_3=\{X \mid \text{Years} \geq 4.5,\ \text{Hits} \geq 117.5\}\).
Prediction via Stratification of the Feature Space
We now discuss the process of building a regression tree. Roughly speaking, there are two steps.
1. We divide the predictor space - that is, the set of possible values for \(X_1, X_2, \ldots, X_p\) - into \(J\) distinct and non-overlapping regions, \(R_1, R_2, \ldots, R_J\), known as the terminal nodes or leaves of the tree. Decision trees are typically drawn upside down, in the sense that the leaves are at the bottom of the tree. The points along the tree where the predictor space is split are referred to as internal nodes. We refer to the segments of the trees that connect the nodes as branches.
2. For every observation that falls into the region \(R_k\), we make the same prediction, which is simply the mean of the response values for the observations in \(R_k\).
How do we construct the regions \(R_1, R_2, \ldots, R_J\)? In theory, the regions could have any shape. However, we choose to divide the predictor space into high-dimensional rectangles, or boxes, for simplicity and for ease of interpretation of the resulting predictive model. The goal is to find boxes \(R_1, R_2, \ldots, R_J\) that minimize the RSS, given by \[ \sum_{k=1}^J \sum_{i \in R_k}\left(y_i-\hat{y}_{R_k}\right)^2 \] where \(\hat{y}_{R_k}\) is the mean response for the observations within the \(k\)th box. Unfortunately, it is computationally infeasible to consider every possible partition of the feature space into \(J\) boxes. For this reason, we take a top-down, greedy approach known as recursive binary splitting. The approach is top-down because it begins at the top of the tree (at which point all observations belong to a single region) and then successively splits the predictor space; each split is indicated via two new branches further down on the tree. It is greedy because at each step of the tree-building process, the best split is made at that particular step, rather than looking ahead and picking a split that will lead to a better tree in some future step. In order to perform recursive binary splitting, we first select the predictor \(X_k\) and the cut-point \(s\) such that splitting the predictor space into the regions \(\left\{X \mid X_k<s\right\}\) and \(\left\{X \mid X_k \geq s\right\}\) leads to the greatest possible reduction in RSS. That is, we consider all predictors \(X_1, X_2, \ldots, X_p\) and all possible values of the cut-point \(s\) for each of the predictors, and then choose the predictor and cut-point such that the resulting tree has the lowest RSS.
In greater detail, for any \(k\) and \(s\), we define the pair of half-planes \[ R_1(k, s)=\left\{X \mid X_k<s\right\} \] and \[ R_2(k, s)=\left\{X \mid X_k \geq s\right\} \] and we seek the values of \(k\) and \(s\) that minimize \[ \sum_{i:\, x_i \in R_1(k, s)}\left(y_i-\hat{y}_{R_1}\right)^2+\sum_{i:\, x_i \in R_2(k, s)}\left(y_i-\hat{y}_{R_2}\right)^2 \] where \(\hat{y}_{R_1}\) is the mean response for the observations in \(R_1(k, s)\), and \(\hat{y}_{R_2}\) is the mean response for the observations in \(R_2(k, s)\). Finding the values of \(k\) and \(s\) that minimize this equation can be done quite quickly, especially when the number of features \(p\) is not too large. Next, we repeat the process, looking for the best predictor and best cut-point in order to split the data further so as to minimize the RSS within each of the resulting regions. However, this time, instead of splitting the entire predictor space, we split one of the two previously identified regions. We now have three regions. Again, we look to split one of these three regions further, so as to minimize the RSS. The process continues until a stopping criterion is reached; for instance, we may continue until no region contains more than five observations. Once the regions \(R_1, R_2, \ldots, R_J\) have been created, we predict the response for a given test observation using the mean of the training observations in the region to which that observation belongs.
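To make the greedy search concrete, here is a minimal R sketch of a single split; best_split is a hypothetical helper (not from any package) that assumes a data frame X of numeric predictors and a numeric response y with no missing values:

best_split = function(X, y) {
  # For every predictor X_k and every candidate cut-point s, split the data
  # into {x < s} and {x >= s}, compute the total RSS of the two half-planes,
  # and keep the (predictor, cut-point) pair with the smallest RSS.
  best = list(rss = Inf, var = NA, cut = NA)
  for (k in seq_along(X)) {
    x = X[[k]]
    for (s in sort(unique(x))) {
      left  = y[x < s]
      right = y[x >= s]
      if (length(left) == 0) next # skip the degenerate split at the minimum
      rss = sum((left - mean(left))^2) + sum((right - mean(right))^2)
      if (rss < best$rss) best = list(rss = rss, var = names(X)[k], cut = s)
    }
  }
  best
}

For instance, best_split(na.omit(Hitters)[, c("Years", "Hits")], na.omit(Hitters)$Salary) would return the best first split for the baseball example; a tree is grown by applying the same search recursively within each resulting region.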
Tree Pruning
The process described above is likely to overfit the data because the resulting tree might be too complex. A smaller tree with fewer splits (that is, fewer regions \(R_1, R_2, \ldots, R_J\) ) might lead to lower variance and better interpretation at the cost of a little bias. One possible alternative to the process described above is to build the tree only so long as the decrease in the RSS due to each split exceeds some (high) threshold. This strategy will result in smaller trees, but is too short-sighted since a seemingly worthless split early on in the tree might be followed by a very good split - that is, a split that leads to a large reduction in RSS later on. Therefore, a better strategy is to grow a very large tree \(T_0\), and then prune it back in order to obtain a subtree. How do we determine the best way to prune the tree? Intuitively, our goal is to select a subtree that leads to the lowest test error rate. Given a subtree, we can estimate its test error using cross-validation or the validation set approach. However, estimating the cross-validation error for every possible subtree would be too cumbersome, since there is an extremely large number of possible subtrees. Instead, we need a way to select a small set of subtrees for consideration.
Cost complexity pruning - also known as weakest link pruning - gives us a way to do just this. Rather than considering every possible subtree, we consider a sequence of trees indexed by a nonnegative tuning parameter \(\alpha\). For each value of \(\alpha\) there corresponds a subtree \(T \subset T_0\) such that
\[ \sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m}\left(y_i-\hat{y}_{R_m}\right)^2+\alpha|T| \]
is as small as possible. Here \(|T|\) indicates the number of terminal nodes of the tree \(T, R_m\) is the rectangle (i.e. the subset of predictor space) corresponding to the \(m\) th terminal node, and \(\hat{y}_{R_m}\) is the predicted response associated with \(R_m\) - that is, the mean of the observations in \(R_m\). The tuning parameter \(\alpha\) controls a trade-off between the subtree’s complexity and its fit to the training data. When \(\alpha=0\), then the subtree \(T\) will simply equal \(T_0\). However, as \(\alpha\) increases, there is a price to pay for having a tree with many terminal nodes, and so the above quantity will tend to be minimized for a smaller subtree. This equation is reminiscent of the lasso regression. It turns out that as we increase \(\alpha\) from zero, branches get pruned from the tree in a nested and predictable fashion, so obtaining the whole sequence of subtrees as a function of \(\alpha\) is easy. We can select a value of \(\alpha\) using a validation set or using cross-validation. We then return to the full data set and obtain the subtree corresponding to \(\alpha\). This process is summarized in the following algorithm:
Building a Regression Tree
1. Use recursive binary splitting to grow a large tree on the training data, stopping only when each terminal node has fewer than some minimum number of observations.
2. Apply cost complexity pruning to the large tree in order to obtain a sequence of best subtrees, as a function of \(\alpha\).
3. Use K-fold cross-validation to choose \(\alpha\). That is, divide the training observations into \(K\) folds. For each \(k=1, \ldots, K\):
a) Repeat Steps 1 and 2 on all but the \(k\)th fold of the training data.
b) Evaluate the mean squared prediction error on the data in the left-out \(k\)th fold, as a function of \(\alpha\).
Average the results for each value of \(\alpha\), and pick \(\alpha\) to minimize the average error.
4. Return the subtree from Step 2 that corresponds to the chosen value of \(\alpha\).
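In R, the tree package carries out all four steps; here is a minimal sketch (the data frame train.df and the response y are hypothetical placeholders):

library(tree)
big.tree  = tree(y ~ ., data = train.df)           # step 1: grow a large tree
cv.out    = cv.tree(big.tree, K = 10)              # steps 2-3: 10-fold CV over the pruning sequence
best.size = cv.out$size[which.min(cv.out$dev)]     # tree size minimizing CV error
final     = prune.tree(big.tree, best = best.size) # step 4: return that subtree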
Example: Boston Housing Data: Regression Tree
Here we fit a regression tree to the Boston data set. First, we create a training set, and fit the tree to the training data.
Fitting a model:
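The summary below was presumably produced along these lines (the seed is an assumption; the formula, the subset argument, and the half/half split are consistent with the output):

library(MASS) # Boston data set
library(tree)
set.seed(1) # assumed seed
train = sample(1:nrow(Boston), nrow(Boston)/2) # half of the 506 suburbs for training
tree.boston = tree(medv ~ ., Boston, subset = train)
summary(tree.boston)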
##
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "crim" "age"
## Number of terminal nodes: 7
## Residual mean deviance: 10.38 = 2555 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -10.1800 -1.7770 -0.1775 0.0000 1.9230 16.5800
Notice that the output of summary() indicates that only
four of the variables have been used in constructing the tree. In the
context of a regression tree, the deviance is simply the sum of squared
errors for the tree.
We now plot the tree.
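A sketch of the plotting code:

plot(tree.boston)
text(tree.boston, pretty = 0) # add the split labels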
The variable lstat measures the percentage of
individuals with lower socioeconomic status. The tree indicates that
lower values of lstat correspond to more expensive houses.
The tree predicts a median house price of $46,400 for larger homes in suburbs in which residents have high socioeconomic status.
Additionally, you can use the rpart package to fit a tree model. With this package you can get nicer visualizations.
library(rpart)
library(rattle)     # needed for fancyRpartPlot()
library(sparkline)  # needed by visNetwork's tree widgets
library(visNetwork) # interactive tree visualization of rpart objects
tree.boston2 = rpart(medv ~ ., Boston, subset = train)
fancyRpartPlot(tree.boston2) # nicer static plot
visTree(tree.boston2)        # interactive plot

You can use the packages above if you need better visualizations.
But we will continue to use the tree package.
Now we use the cv.tree() function to see whether pruning
the tree will improve performance.
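A sketch of that step:

cv.boston = cv.tree(tree.boston)
plot(cv.boston$size, cv.boston$dev, type = "b") # CV error versus tree size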
In this case, the most complex tree is selected by cross-validation.
However, if we wish to prune the tree, we could do so as follows, using
the prune.tree() function:
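For example, to keep only five terminal nodes (best = 5 is illustrative):

prune.boston = prune.tree(tree.boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)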
In keeping with the cross-validation results, we use the un-pruned tree
to make predictions on the test set.
yhat=predict(tree.boston,newdata=Boston[-train,])
boston.test=Boston[-train,"medv"] # test-set response values
mean((yhat-boston.test)^2)        # test MSE

## [1] 35.28688
In other words, the test set MSE associated with the regression tree is 35.29. The square root of the MSE is therefore around 5.94, indicating that this model leads to test predictions that are within around $5,940 of the true median home value for the suburb.
Advantages and Disadvantages of Trees
Decision trees for regression and classification have a number of advantages over the more classical approaches seen in earlier lessons.
Trees are very easy to explain to people. In fact, they are even easier to explain than linear regression!
Some people believe that decision trees more closely mirror human decision-making than do the regression and classification approaches seen in previous chapters.
Trees can be displayed graphically, and are easily interpreted even by a non-expert (especially if they are small).
Trees can easily handle qualitative predictors without the need to create dummy variables.
Unfortunately, trees generally do not have the same level of predictive accuracy as some of the other regression and classification approaches seen in this book.
Additionally, trees can be very non-robust. In other words, a small change in the data can cause a large change in the final estimated tree.
Random Forests
Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of decision trees on bootstrapped samples. But when building these decision trees, each time a split in a tree is considered, a random sample of \(m\) predictors is chosen as split candidates from the full set of \(p\) predictors. The split is allowed to use only one of those \(m\) predictors. A fresh sample of \(m\) predictors is taken at each split, and typically we choose \(m \approx \sqrt{p}\). In other words, in building a random forest, at each split in the tree, the algorithm is not even allowed to consider a majority of the available predictors. This may sound crazy, but it has a clever rationale.
Suppose that there is one very strong predictor in the data set, along with a number of other moderately strong predictors. Then in the collection of bagged trees, most or all of the trees will use this strong predictor in the top split. Consequently, all of the bagged trees will look quite similar to each other. Hence the predictions from the bagged trees will be highly correlated. Unfortunately, averaging many highly correlated quantities does not lead to as large of a reduction in variance as averaging many uncorrelated quantities. In particular, this means that bagging will not lead to a substantial reduction in variance over a single tree in this setting. Random forests overcome this problem by forcing each split to consider only a subset of the predictors. Therefore, on average \((p-m) / p\) of the splits will not even consider the strong predictor, and so other predictors will have more of a chance. We can think of this process as decorrelating the trees, thereby making the average of the resulting trees less variable and hence more reliable.
Example: Boston Housing Data (cont.)
Growing a random forest proceeds in exactly the same way, except that we use a smaller value of the mtry argument. By default, randomForest() uses \(p/3\) variables when building a random forest of regression trees, and \(\sqrt{p}\) variables when building a random forest of classification trees. Here we use mtry = 6.
library(randomForest)
set.seed(1)
rf.boston=randomForest(medv~.,data=Boston,subset=train,mtry=6,importance=TRUE)
yhat.rf = predict(rf.boston,newdata=Boston[-train,])
mean((yhat.rf-boston.test)^2) # test set MSE

## [1] 19.62021
Variable importance
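The table below was presumably produced by the importance() function:

importance(rf.boston)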
## %IncMSE IncNodePurity
## crim 16.697017 1076.08786
## zn 3.625784 88.35342
## indus 4.968621 609.53356
## chas 1.061432 52.21793
## nox 13.518179 709.87339
## rm 32.343305 7857.65451
## age 13.272498 612.21424
## dis 9.032477 714.94674
## rad 2.878434 95.80598
## tax 9.118801 364.92479
## ptratio 8.467062 823.93341
## black 7.579482 275.62272
## lstat 27.129817 6027.63740
Two measures of variable importance are reported. The first is the mean decrease in prediction accuracy when the values of a given variable are permuted, thus breaking its relationship with the response. The second is a measure of the total decrease in node impurity that results from splits over that variable, averaged over all trees. In the case of regression trees, the node impurity is measured by the training RSS, and for classification trees by the deviance. Plots of these importance measures can be produced using the varImpPlot() function.
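For example:

varImpPlot(rf.boston)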
Example: Carseats data set: Classification tree
For this example let’s use the Carseats data set.
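The summary below was presumably produced with:

library(ISLR) # Carseats data set
summary(Carseats)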
## Sales CompPrice Income Advertising
## Min. : 0.000 Min. : 77 Min. : 21.00 Min. : 0.000
## 1st Qu.: 5.390 1st Qu.:115 1st Qu.: 42.75 1st Qu.: 0.000
## Median : 7.490 Median :125 Median : 69.00 Median : 5.000
## Mean : 7.496 Mean :125 Mean : 68.66 Mean : 6.635
## 3rd Qu.: 9.320 3rd Qu.:135 3rd Qu.: 91.00 3rd Qu.:12.000
## Max. :16.270 Max. :175 Max. :120.00 Max. :29.000
## Population Price ShelveLoc Age Education
## Min. : 10.0 Min. : 24.0 Bad : 96 Min. :25.00 Min. :10.0
## 1st Qu.:139.0 1st Qu.:100.0 Good : 85 1st Qu.:39.75 1st Qu.:12.0
## Median :272.0 Median :117.0 Medium:219 Median :54.50 Median :14.0
## Mean :264.8 Mean :115.8 Mean :53.32 Mean :13.9
## 3rd Qu.:398.5 3rd Qu.:131.0 3rd Qu.:66.00 3rd Qu.:16.0
## Max. :509.0 Max. :191.0 Max. :80.00 Max. :18.0
## Urban US
## No :118 No :142
## Yes:282 Yes:258
Let’s create a binary indicator variable, High, from the continuous Sales variable:
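A sketch, assuming the ISLR convention of splitting Sales at 8 (thousand units), which matches the printed values below:

High = ifelse(Carseats$Sales <= 8, "No", "Yes") # "Yes" if Sales exceeds 8
High = as.factor(High)
Carseats = data.frame(Carseats, High) # attach the new variable
High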
## [1] Yes Yes Yes No No Yes No Yes No No Yes Yes No Yes Yes Yes No Yes
## [19] Yes Yes No Yes No No Yes Yes Yes No No No Yes Yes No Yes No Yes
## [37] Yes No No No No No Yes No No No Yes No No Yes No No No No
## [55] No No Yes No No No Yes No No Yes No No Yes Yes Yes No Yes No
## [73] No Yes No Yes Yes No No Yes Yes No Yes No No Yes Yes Yes No No
## [91] No No No Yes Yes No Yes No Yes No No No No No No No No Yes
## [109] No Yes Yes No No No Yes Yes No Yes No No No Yes No Yes Yes Yes
## [127] Yes No No No Yes No Yes No No No No No Yes Yes No No No No
## [145] Yes Yes No Yes No Yes Yes Yes No No No No No Yes Yes Yes No No
## [163] No No Yes No No No No Yes Yes Yes Yes No No No No Yes Yes No
## [181] No No No No Yes Yes Yes No Yes Yes Yes No No Yes No No No No
## [199] No No No No No No Yes No No Yes No No No Yes Yes Yes No No
## [217] No No Yes Yes Yes No No No No No No Yes No Yes No Yes Yes Yes
## [235] Yes No Yes Yes No No Yes Yes No No Yes Yes No No No No Yes No
## [253] Yes No Yes No No Yes No No No No No No No No Yes No No No
## [271] Yes No Yes Yes No No No No No No No Yes No No No No No No
## [289] No Yes Yes No Yes Yes Yes No Yes No Yes Yes Yes No No Yes Yes Yes
## [307] No No Yes Yes Yes No No Yes No No Yes No Yes No No No Yes Yes
## [325] No Yes No No No Yes No Yes No No No No No Yes No Yes No No
## [343] No No Yes No Yes No Yes Yes Yes Yes Yes Yes No No No Yes No No
## [361] Yes Yes No Yes Yes No No Yes Yes Yes No Yes No No Yes No Yes No
## [379] No No Yes No No Yes Yes No No Yes Yes Yes No No No No No Yes
## [397] No No No Yes
## Levels: No Yes
Fit a tree to predict the High variable using all variables except Sales:
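The formula shown in the summary confirms the call:

tree.carseats = tree(High ~ . - Sales, Carseats)
summary(tree.carseats)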
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
The training error rate is 9%. See page 325 in ISLR for the residual mean deviance (RMD) formula; the smaller the RMD, the better.
Plot the tree, and print the fitted object to inspect each branch:
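A sketch of the code that produces the branch listing below:

plot(tree.carseats)
text(tree.carseats, pretty = 0)
tree.carseats # printing the object lists every branch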
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 No ( 0.59000 0.41000 )
## 2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
## 4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
## 8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
## 16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
## 17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
## 9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
## 18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
## 19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
## 5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
## 10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
## 20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
## 40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
## 80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
## 160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
## 161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
## 81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
## 41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
## 21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
## 42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
## 84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
## 85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
## 170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
## 171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
## 342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
## 343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
## 43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
## 86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
## 87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
## 174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
## 348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
## 349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
## 175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
## 11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
## 22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
## 44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
## 88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
## 89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
## 45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
## 23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
## 46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
## 47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
## 94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
## 95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
## 3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
## 6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
## 12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
## 24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
## 25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
## 13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
## 7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
## 14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
Typing the name of the tree object prints output corresponding to each branch of the tree. R displays the split criterion (e.g. Price < 92.5), the number of observations in that branch, the deviance, the overall prediction for the branch (Yes or No), and the fraction of observations in that branch that take on values of Yes and No. Branches that lead to terminal nodes are indicated using asterisks.
Now, let’s split the data into training and test sets to properly assess performance:
set.seed(2)
# indices for 200 observations randomly selected as training samples
train=sample(1:nrow(Carseats), 200)
Carseats.test=Carseats[-train,] # test set: the remaining observations
High.test=High[-train]          # test-set response values

Fit the tree on the training data by specifying subset = train:
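Based on the summary output, the fit was presumably:

tree.carseats = tree(High ~ . - Sales, Carseats, subset = train)
summary(tree.carseats)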
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats, subset = train)
## Variables actually used in tree construction:
## [1] "Price" "Population" "ShelveLoc" "Age" "Education"
## [6] "CompPrice" "Advertising" "Income" "US"
## Number of terminal nodes: 21
## Residual mean deviance: 0.5543 = 99.22 / 179
## Misclassification error rate: 0.115 = 23 / 200
Now let’s predict:
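The class labels below were presumably obtained with:

tree.pred = predict(tree.carseats, Carseats.test, type = "class")
tree.pred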
## [1] Yes No No Yes No No Yes Yes Yes No No No No Yes Yes No Yes No
## [19] No No No No No No No No No No Yes No No No No No Yes No
## [37] Yes Yes No Yes Yes Yes No No No No No Yes No No No No Yes No
## [55] No No No No No No Yes No No No No No Yes Yes No Yes Yes No
## [73] Yes Yes Yes No No No No Yes Yes Yes No No No Yes Yes No Yes No
## [91] No No No No No Yes No No No No No Yes Yes No No Yes Yes Yes
## [109] Yes No Yes No No Yes No No Yes No No Yes No No No No No No
## [127] No No No No Yes No Yes Yes No Yes No No Yes No Yes Yes No No
## [145] No No Yes No No No No No Yes No Yes No Yes No No Yes No No
## [163] No No No Yes No No Yes No Yes No No No Yes No Yes Yes No No
## [181] No No No No No No Yes No No No Yes No No No No No No No
## [199] Yes No
## Levels: No Yes
Note that, unlike with the regression tree, we cannot calculate an MSE for the classification tree, because the predictions are class labels rather than numerical values. Therefore, we need another way to evaluate the model.
Calculating error rates for classification problems
Consider the following example: a classifier that predicts whether each of 165 subjects has a disease. The resulting table of predicted versus actual classes is called a confusion matrix:
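              Predicted: No   Predicted: Yes   Total
Actual: No        TN = 50         FP = 10        60
Actual: Yes       FN = 5          TP = 100      105
Total                  55             110       165

These are the counts used in the rate calculations below.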
Define the terms:
True Positives (TP): These are cases in which we predicted yes (they have the disease), and they do have the disease.
True Negatives (TN): We predicted no, and they don’t have the disease.
False Positives (FP): We predicted yes, but they don’t actually have the disease. (Also known as a “Type I error.”)
False Negatives (FN): We predicted no, but they actually do have the disease. (Also known as a “Type II error.”)
This is a list of rates that are often computed from a confusion matrix for a binary classifier:
Accuracy: Overall, how often is the classifier correct?
- (TP+TN)/total = (100+50)/165 = 0.91

Misclassification Rate: Overall, how often is it wrong?
- (FP+FN)/total = (10+5)/165 = 0.09
- equivalent to 1 - Accuracy
- also known as “Error Rate”

True Positive Rate: When it’s actually yes, how often does it predict yes?
- TP/actual yes = 100/105 = 0.95
- also known as “Sensitivity” or “Recall”

False Positive Rate: When it’s actually no, how often does it predict yes?
- FP/actual no = 10/60 = 0.17

True Negative Rate: When it’s actually no, how often does it predict no?
- TN/actual no = 50/60 = 0.83
- equivalent to 1 - False Positive Rate
- also known as “Specificity”

Precision: When it predicts yes, how often is it correct?
- TP/predicted yes = 100/110 = 0.91

Prevalence: How often does the yes condition actually occur in our sample?
- actual yes/total = 105/165 = 0.64
Example: Carseats data set (cont.)
Now let’s calculate the confusion matrix for this.
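A sketch:

table(tree.pred, High.test)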
## High.test
## tree.pred No Yes
## No 104 33
## Yes 13 50
(104+50)/200 = 0.77: the tree correctly predicts 77% of the test observations.
Or you can use the caret package to calculate the
confusion matrix.
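For example:

library(caret)
confusionMatrix(tree.pred, High.test)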
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 104 33
## Yes 13 50
##
## Accuracy : 0.77
## 95% CI : (0.7054, 0.8264)
## No Information Rate : 0.585
## P-Value [Acc > NIR] : 2.938e-08
##
## Kappa : 0.5091
##
## Mcnemar's Test P-Value : 0.005088
##
## Sensitivity : 0.8889
## Specificity : 0.6024
## Pos Pred Value : 0.7591
## Neg Pred Value : 0.7937
## Prevalence : 0.5850
## Detection Rate : 0.5200
## Detection Prevalence : 0.6850
## Balanced Accuracy : 0.7456
##
## 'Positive' Class : No
##
Now let’s prune the tree.
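A sketch (the seed is an assumption):

set.seed(3) # assumed seed
cv.carseats = cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)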
## [1] "size" "dev" "k" "method"
The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost complexity pruning is used in order to select a sequence of trees for consideration. We use the argument FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance. The cv.tree() function reports the number of terminal nodes of each tree considered (size), the corresponding cross-validation error rate (dev), and the value of the cost-complexity parameter used (k). The default is 10-fold cross-validation.
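Printing the object shows these quantities:

cv.carseats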
## $size
## [1] 21 19 14 9 8 5 3 2 1
##
## $dev
## [1] 72 73 72 72 77 77 78 83 84
##
## $k
## [1] -Inf 0.0 1.0 1.4 2.0 3.0 4.0 9.0 18.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
Note that, despite the name, dev corresponds to the cross-validation error rate in this instance. The tree with 9 terminal nodes is the smallest tree that achieves the lowest cross-validation error rate, with 72 cross-validation errors.
We plot the error rate as a function of both size and k.
par(mfrow=c(1,2))
plot(cv.carseats$size,cv.carseats$dev,type="b")
plot(cv.carseats$k,cv.carseats$dev,type="b")

We now apply the prune.misclass() function in order to prune the tree to obtain the nine-node tree.

prune.carseats=prune.misclass(tree.carseats,best=9)
plot(prune.carseats)
text(prune.carseats,pretty=0)

How well does this pruned tree perform on the test data set? Once again, we apply the predict() function.
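A sketch of the prediction step (the caret confusion matrix follows the simple table):

tree.pred = predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
confusionMatrix(tree.pred, High.test)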
## High.test
## tree.pred No Yes
## No 97 25
## Yes 20 58
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 97 25
## Yes 20 58
##
## Accuracy : 0.775
## 95% CI : (0.7108, 0.8309)
## No Information Rate : 0.585
## P-Value [Acc > NIR] : 1.206e-08
##
## Kappa : 0.5325
##
## Mcnemar's Test P-Value : 0.551
##
## Sensitivity : 0.8291
## Specificity : 0.6988
## Pos Pred Value : 0.7951
## Neg Pred Value : 0.7436
## Prevalence : 0.5850
## Detection Rate : 0.4850
## Detection Prevalence : 0.6100
## Balanced Accuracy : 0.7639
##
## 'Positive' Class : No
##
(97+58)/200 = 0.775, so 77.5% of the test observations are correctly classified. Not only has the pruning process produced a more interpretable tree, but it has also improved the classification accuracy.
If we decrease the value of best further, we obtain a smaller pruned tree with lower classification accuracy:
prune.carseats=prune.misclass(tree.carseats,best=8)
plot(prune.carseats)
text(prune.carseats,pretty=0)

## High.test
## tree.pred No Yes
## No 89 21
## Yes 28 62
## [1] 0.755
## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 89 21
## Yes 28 62
##
## Accuracy : 0.755
## 95% CI : (0.6894, 0.8129)
## No Information Rate : 0.585
## P-Value [Acc > NIR] : 3.611e-07
##
## Kappa : 0.5015
##
## Mcnemar's Test P-Value : 0.3914
##
## Sensitivity : 0.7607
## Specificity : 0.7470
## Pos Pred Value : 0.8091
## Neg Pred Value : 0.6889
## Prevalence : 0.5850
## Detection Rate : 0.4450
## Detection Prevalence : 0.5500
## Balanced Accuracy : 0.7538
##
## 'Positive' Class : No
##
Now let’s try to fit a random forest model for this data set.
library(randomForest)
set.seed(1)
rf.carseats=randomForest(High~.-Sales,data=Carseats,subset=train,mtry=3,importance=TRUE)
yhat.rf = predict(rf.carseats,newdata=Carseats.test)
confusionMatrix(yhat.rf,High.test)

## Confusion Matrix and Statistics
##
## Reference
## Prediction No Yes
## No 110 24
## Yes 7 59
##
## Accuracy : 0.845
## 95% CI : (0.7873, 0.8922)
## No Information Rate : 0.585
## P-Value [Acc > NIR] : 1.939e-15
##
## Kappa : 0.671
##
## Mcnemar's Test P-Value : 0.004057
##
## Sensitivity : 0.9402
## Specificity : 0.7108
## Pos Pred Value : 0.8209
## Neg Pred Value : 0.8939
## Prevalence : 0.5850
## Detection Rate : 0.5500
## Detection Prevalence : 0.6700
## Balanced Accuracy : 0.8255
##
## 'Positive' Class : No
##
Here we can see a much improved accuracy rate (84.5%) for the random forest model.